#!/usr/bin/python
'''
airPLS.py Copyright 2014 Renato Lombardo - renato.lombardo@unipa.it
Baseline correction using adaptive iteratively reweighted penalized least squares

This program is a translation into Python of the R source code of airPLS version 2.0
by Yizeng Liang and Zhang Zhimin - https://code.google.com/p/airpls
Reference:
Z.-M. Zhang, S. Chen, and Y.-Z. Liang, Baseline correction using adaptive iteratively reweighted penalized least squares. Analyst 135 (5), 1138-1146 (2010).

Description from the original documentation:

Baseline drift always blurs or even swamps signals and deteriorates analytical results, particularly in multivariate analysis. It is necessary to correct baseline drift to perform further data analysis. Simple or modified polynomial fitting has been found to be effective to some extent. However, this method requires user intervention and is prone to variability, especially in low signal-to-noise ratio environments. The proposed adaptive iteratively reweighted Penalized Least Squares (airPLS) algorithm doesn't require any user intervention or prior information, such as detected peaks. It iteratively changes the weights of the sum of squared errors (SSE) between the fitted baseline and the original signal, and these weights are obtained adaptively from the difference between the previously fitted baseline and the original signal. This baseline estimator is general, fast and flexible in fitting baselines.


LICENCE
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>
'''

import warnings

import numpy as np
from scipy.sparse import csc_matrix, eye, diags
from scipy.sparse.linalg import spsolve


class BaselineCorrection:
    """Remove baseline drift from multi-lead signals with airPLS.

    Every lead is processed independently: an airPLS baseline is estimated
    for it and subtracted from it.
    """

    def __init__(self, lambda_=100, itermax=15, order=1) -> None:
        """
        lambda_: Smoothing parameter for airPLS. The larger lambda_ is, the
            smoother the fitted background z.
        itermax: Maximum number of airPLS iterations (must be >= 1).
        order: Order of the differences used in the smoothness penalty of the
            Whittaker smoother (1 penalizes first differences, 2 penalizes
            second differences, ...).
        """
        self.lambda_ = lambda_
        self.itermax = itermax
        self.order = order

    def _whittakerSmooth(self, x, w, lambda_, differences=1):
        '''
        Penalized least squares algorithm for background fitting.

        Solves (W + lambda_ * D^T D) z = W x for z, where W = diag(w) and D is
        the `differences`-th order difference matrix.

        input
            x: 1-D input data (i.e. chromatogram or spectrum)
            w: per-point non-negative weights; a zero weight removes that point
               from the data-fidelity term (e.g. points belonging to peaks)
            lambda_: smoothing parameter. The larger lambda_ is, the smoother
               the resulting background
            differences: integer indicating the order of the difference of penalties

        output
            the fitted background vector (1-D float array, same length as x)
        '''
        x = np.asarray(x, dtype=float).ravel()
        m = x.size
        D = eye(m, format='csc')
        for _ in range(differences):
            # numpy.diff() does not work with sparse matrices; differencing
            # consecutive rows is the equivalent operation.
            D = D[1:] - D[:-1]

        W = diags(w, 0, shape=(m, m))
        A = csc_matrix(W + lambda_ * (D.T @ D))
        B = W @ x  # dense 1-D right-hand side -> spsolve returns a 1-D array
        background = spsolve(A, B)

        return np.asarray(background)

    def _airPLS(self, x):
        '''
        Adaptive iteratively reweighted penalized least squares for baseline fitting.

        input
            x: 1-D input data (i.e. chromatogram or spectrum)

        output
            the fitted background vector z (1-D float array, same length as x)

        raises
            ValueError: if self.itermax < 1 (no fit would ever be computed).
        '''
        if self.itermax < 1:
            raise ValueError('itermax must be >= 1, got %r' % (self.itermax,))

        # Work in float64: avoids integer overflow in abs()/sums for
        # integer-typed signals (e.g. abs(int16(-32768))).
        x = np.asarray(x, dtype=float)
        w = np.ones(x.shape[0])
        threshold = 0.001 * np.abs(x).sum()

        for i in range(1, self.itermax + 1):
            z = self._whittakerSmooth(x, w, self.lambda_, self.order)
            d = x - z
            neg = d < 0
            dssn = np.abs(d[neg].sum())

            # '<=' (not '<') also stops for an identically-zero signal, where
            # dssn == threshold == 0 and there are no negative residuals to
            # derive new weights from.
            if dssn <= threshold:
                break
            if i == self.itermax:
                warnings.warn('airPLS: maximum number of iterations reached',
                              RuntimeWarning, stacklevel=2)
                break

            # Here dssn > 0, so at least one residual is negative.
            # d >= 0 means that this point is part of a peak, so its weight is
            # set to 0 in order to ignore it.
            w[~neg] = 0
            w[neg] = np.exp(i * np.abs(d[neg]) / dssn)
            # Both end points get the same positive weight (mirrors the R
            # reference code), so they always contribute to the next fit.
            w[0] = np.exp(i * d[neg].max() / dssn)
            w[-1] = w[0]

        return z

    def __call__(self, sample_ecg):
        """
        Apply airPLS baseline correction to each lead of an ECG signal.

        Parameters:
            sample_ecg: A (leads, samples) array of ECG data. A 1-D array is
                treated as a single lead.

        Returns:
            A numpy array of the same shape as sample_ecg, with baselines
            corrected. Floating-point inputs keep their dtype; integer/boolean
            inputs are returned as float64 so the subtracted baseline is not
            truncated to integers.
        """
        sample_ecg = np.asarray(sample_ecg)
        if sample_ecg.ndim == 1:
            return self(sample_ecg[np.newaxis, :])[0]

        if np.issubdtype(sample_ecg.dtype, np.inexact):
            out_dtype = sample_ecg.dtype
        else:
            out_dtype = np.float64
        corrected_ecg = np.zeros(sample_ecg.shape, dtype=out_dtype)

        for lead, signal in enumerate(sample_ecg):
            corrected_ecg[lead, :] = signal - self._airPLS(signal)

        return corrected_ecg